#!/usr/bin/python
# -*- coding: utf-8 -*-
# vim: set syntax=python

from __future__ import print_function

import sys
import os
import re
import copy
import random
import httplib
import json
import traceback
import sre_constants
from collections import defaultdict
from optparse import OptionParser


def _exec(s, g, e):
    """Execute the Python source ``s`` with globals ``g`` and locals ``e``.

    Thin wrapper around the built-in ``exec`` so the rest of this module
    runs code snippets through a single call site.
    """
    exec(s, g, e)

# Captures the line number from a traceback frame of code that was run from
# a string (such frames are reported as ``File "<string>", line N``).
# Raw string: ``\d`` is a regex escape, not a string escape.
EXEC_LINE_RE = re.compile(r'File "<string>", line (\d+)')

def error(s, **kw):
    """Print ``s`` wrapped in ANSI colour code 91 (bright red), then flush stdout.

    Any keyword arguments are forwarded to ``print`` unchanged.
    """
    colored = '\033[91m%s\033[0m' % s
    print(colored, **kw)
    sys.stdout.flush()


def info(s, **kw):
    """Print ``s`` wrapped in ANSI colour code 92 (bright green), then flush stdout.

    Any keyword arguments are forwarded to ``print`` unchanged.
    """
    colored = '\033[92m%s\033[0m' % s
    print(colored, **kw)
    sys.stdout.flush()


def warn(s, **kw):
    """Print ``s`` wrapped in ANSI colour code 95 (bright magenta), then flush stdout.

    Any keyword arguments are forwarded to ``print`` unchanged.
    """
    colored = '\033[95m%s\033[0m' % s
    print(colored, **kw)
    sys.stdout.flush()


class ParseError(Exception):
    """Raised when a test file contains a line the parser cannot place.

    Attributes:
        fpath: path of the file being parsed.
        line_no: number of the offending line.
        line: the offending line itself.
    """

    def __init__(self, fpath, line_no, line):
        # Initialise the base class so that str(e) and e.args describe the
        # failure instead of being empty.
        super(ParseError, self).__init__(
            'invalid line %s:%s: %s' % (fpath, line_no, line))
        self.fpath = fpath
        self.line_no = line_no
        self.line = line


# filters

def http_wrap(s, unused):
    """Turn a newline-separated request text into a CRLF-framed HTTP message.

    Everything before the first empty line is the header section, everything
    after it is the body.  A ``Content-Length`` header (length of the body) is
    appended unless one is already present under any capitalisation.  Header
    lines are joined with CRLF; the body keeps its ``\\n`` separators.  The
    second argument is ignored.
    """
    lines = s.split('\n')
    try:
        blank_at = lines.index('')
    except ValueError:
        # No blank line: the whole text is headers and the body is empty.
        blank_at = len(lines)

    head_lines = lines[:blank_at]
    body = '\n'.join(lines[blank_at + 1:])

    has_length = False
    for h in head_lines:
        if h.split(':', 1)[0].title() == 'Content-Length':
            has_length = True
            break
    if not has_length:
        head_lines.append('Content-Length: %s' % len(body))

    header = '\r\n'.join(head_lines)
    return '%s\r\n\r\n%s' % (header, body)


def python_exec(s, e):
    """Filter that executes ``s`` as Python code with ``e`` as its locals.

    Returns ``s`` unchanged.
    """
    _exec(s, None, e)
    return s


class Block(object):
    """A named section of a test case, with an optional filter pipeline.

    The head has the form ``name | filter | filter ...`` (whitespace is
    ignored).  The body lines are stored with surrounding blank lines
    removed; ``{{ expr }}`` placeholders in them are evaluated lazily
    against an environment dict.
    """

    # Matches ``{{ expr }}`` placeholders (non-greedy, single line).
    EVAL_RE = re.compile(r'{{(.+?)}}')

    # Built-in filters; a callable of the same name in the environment
    # takes precedence (see ``filter``).
    FILTERS = {
        'http': http_wrap,
        'python': python_exec,
    }

    # Filter applied to a block of the given name when its head names none.
    DEFAULT_FILTER = {
        'request': 'http',
        'setup': 'python',
        'teardown': 'python',
    }

    def __init__(self, head, lines):
        """Parse ``head`` into name and filters and normalise ``lines``.

        Args:
            head: block header text, e.g. ``'request | http'``.
            lines: raw body lines of the block.
        """
        head = re.sub(r'\s', '', head)
        heads = head.split('|')
        self.name = heads[0]
        self.filters = heads[1:]

        # set default filter
        default_flt = self.DEFAULT_FILTER.get(self.name)
        if default_flt and not self.filters:
            self.filters.append(default_flt)

        # strip empty lines at both ends
        i, j = 0, len(lines) - 1
        while i <= j and lines[i].strip() == '':
            i += 1
        while i <= j and lines[j].strip() == '':
            j -= 1

        # A literal backslash-r at the end of a line stands for a real
        # carriage return.
        self.raw_lines = [re.sub(r'\\r$', '\r', l.rstrip())
                for l in lines[i:j+1]]

        # A final ``__EOF__`` marker becomes an empty last line, so the
        # joined content ends with a newline.
        if self.raw_lines and self.raw_lines[-1] == '__EOF__':
            self.raw_lines[-1] = ''

    def __str__(self):
        if self.filters:
            return '<Block, name=%s, filters=%s>' % \
                    (self.name, '|'.join(self.filters))
        return '<Block, name=%s>' % (self.name)

    def __repr__(self):
        return self.__str__()

    @property
    def raw_content(self):
        """The unevaluated body, lines joined with ``\\n``."""
        return '\n'.join(self.raw_lines)

    @classmethod
    def evaluate(cls, env, content):
        """Replace every ``{{ expr }}`` in ``content`` with ``str(eval(expr))``.

        ``env`` is used as the local namespace of the evaluation.
        """
        # NOTE(security): eval() runs arbitrary expressions taken from the
        # test file; only run test files you trust.
        return cls.EVAL_RE.sub(
                    lambda m: str(eval(m.group(1), None, env)),
                    content)

    def lines(self, env):
        """Return the body lines with placeholders evaluated line by line."""
        return [self.evaluate(env, l) for l in self.raw_lines]

    def content(self, env):
        """Return the whole body as one string with placeholders evaluated."""
        return self.evaluate(env, self.raw_content)

    def filter(self, env, raw_content=None):
        """Run the block's filters over some content and return the result.

        Args:
            env: environment dict; a callable stored under a filter's name
                overrides the built-in filter of that name.
            raw_content: content to filter.  When ``None`` (the default) the
                block's own evaluated content is used.

        Raises:
            Exception: if a filter name cannot be resolved.
        """
        # Compare against None explicitly: an empty string (e.g. an empty
        # response body) is real content and must not be replaced by the
        # block's own text.
        if raw_content is None:
            content = self.content(env)
        else:
            content = raw_content
        for flt in self.filters:
            if flt == 'raw':
                continue
            flt_func = env.get(flt) or self.FILTERS.get(flt)
            if not flt_func:
                raise Exception('Can not find filter \'%s\'' % flt)
            content = flt_func(content, env)
        return content


class Case(object):
    """A named test case holding its blocks, keyed by block name."""

    # A case named ``Test <series>.<something> :`` belongs to <series>.
    # Raw string: ``\w`` and ``\.`` are regex escapes, not string escapes.
    SERIES_PAT = re.compile(r'Test *(\w+)\.\w+ *:')

    def __init__(self, name):
        self.name = name
        self.blocks = {}

    def set_block(self, block):
        """Store ``block`` under its name, replacing any earlier block."""
        self.blocks[block.name] = block

    def get_block(self, name):
        """Return the block called ``name``, or None if there is none."""
        return self.blocks.get(name)

    def __str__(self):
        return '<Case, %s>' % (self.name)

    def __repr__(self):
        return self.__str__()

    @property
    def series(self):
        """Series name parsed from the case name, or None if it has none."""
        m = self.SERIES_PAT.search(self.name)
        return m.group(1) if m else None


def parse_file(fpath):
    """Parse a test file into its environment lines and its cases.

    Layout recognised by the parser:
      * lines before the first ``===`` are collected verbatim as
        environment lines;
      * a line starting with ``===`` opens a new case named by the rest of
        the line;
      * a line starting with ``---`` opens a new block of the current case,
        the rest of the line being the block head;
      * outside the environment section, lines whose first non-blank
        character is ``#`` are skipped.

    Returns:
        ``(env_lines, cases)`` -- list of raw lines and list of Case objects.

    Raises:
        ParseError: on ``---`` before any case, or on non-blank text between
            a case header and its first block.
    """
    IN_ENV, IN_CASE, IN_BLOCK = range(3)

    def attach_pending(owner, head, body):
        # Build and store the block collected so far, if there is one.
        if head:
            owner.set_block(Block(head, body))

    parsed_cases = []
    preamble = []
    pending_lines = []
    pending_head = None
    current = None
    state = IN_ENV

    with open(fpath, 'r') as f:
        for line_no, line in enumerate(f, 1):
            if state != IN_ENV and line.strip().startswith('#'):
                continue

            if line.startswith('==='):
                if current:
                    attach_pending(current, pending_head, pending_lines)
                    pending_head = None
                    parsed_cases.append(current)
                current = Case(line[3:].strip())
                state = IN_CASE
            elif line.startswith('---'):
                if state not in (IN_CASE, IN_BLOCK):
                    raise ParseError(fpath, line_no, line)
                attach_pending(current, pending_head, pending_lines)
                pending_head = line[3:].strip()
                pending_lines = []
                state = IN_BLOCK
            elif state == IN_ENV:
                preamble.append(line)
            elif state == IN_BLOCK:
                pending_lines.append(line)
            elif line.strip():
                raise ParseError(fpath, line_no, line)

        # Flush whatever was still open at end of file.
        if current:
            attach_pending(current, pending_head, pending_lines)
            parsed_cases.append(current)

    return preamble, parsed_cases


def gather_files(test_dir):
    """Collect the test files designated by ``test_dir``.

    If ``test_dir`` is a directory it is walked recursively and every file
    whose name ends with ``yt`` is returned (joined with its directory).
    Otherwise ``test_dir`` is treated as a single file path and returned in a
    one-element list when it ends with ``yt``; an empty list is returned if
    it does not.
    """
    if not os.path.isdir(test_dir):
        return [test_dir] if test_dir.endswith('yt') else []

    found = []
    for root, _subdirs, names in os.walk(test_dir):
        for name in names:
            if name.endswith('yt'):
                found.append(os.path.join(root, name))
    return found


def gather_cases(test_files, filter=None):
    """Parse every test file and prepare its environment.

    For each file the environment lines are executed as Python into a fresh
    dict (pre-seeded with ``sys``).  Any failure while parsing or executing
    is reported with ``warn`` and terminates the process with status 1.

    Args:
        test_files: iterable of test file paths.
        filter: optional regular expression; when given, only cases whose
            name it matches anywhere are kept, and files left without any
            case are dropped.

    Returns:
        A list of ``(fpath, (env, cases))`` tuples.
    """
    test_cases = []
    for fpath in test_files:
        env = {'sys': sys}
        try:
            env_lines, cases = parse_file(fpath)
            if env_lines:
                _exec(''.join(env_lines), None, env)
        except ParseError as e:
            warn('ParseError in %s:%s, invalid line: %s' % \
                    (e.fpath, e.line_no, e.line))
            sys.exit(1)
        except NameError as e:
            # str(e) rather than e.message: the latter is deprecated and
            # does not exist on Python 3.
            warn('NameError in %s: %s' % (fpath, str(e)))
            sys.exit(1)
        except SyntaxError as e:
            warn("SyntaxError in python block %s:%s" % \
                    (fpath, e.lineno))
            sys.exit(1)
        except Exception as e:
            warn('Exception in %s: %s' % (fpath, str(e)))
            sys.exit(1)
        else:
            if filter:
                cs = [c for c in cases if re.search(filter, c.name)]
                if cs:
                    test_cases.append((fpath, (env, cs)))
            else:
                test_cases.append((fpath, (env, cases)))

    return test_cases


class Server(object):
    """Controls the server under test through shell commands from the env.

    Environment keys read: ``STOP_SERVER``, ``PRE_START_SERVER``,
    ``START_SERVER``, ``RELOAD_SERVER`` (shell commands), ``CONFIG`` and
    ``CONFIG_PATH`` (copied source -> destination before start) and
    ``LOG_PATH`` (server log file inspected by ``grep_log``).
    """

    # Log lines are selected with ``re.match``, i.e. anchored at line start.
    ERR_PAT = r'.+\[error\]'
    WARN_PAT = r'.+\[warn\]'

    def __init__(self, env):
        self.env = env
        self.log_path = env.get('LOG_PATH', '')

    def _run(self, key):
        """Run the shell command stored under ``key``; do nothing if unset."""
        cmd = self.env.get(key, '')
        if cmd:
            os.system(cmd)

    def stop(self):
        """Run the ``STOP_SERVER`` command."""
        self._run('STOP_SERVER')

    def start(self):
        """Stop any running server, install the config, start, clear the log."""
        self.stop()
        self._run('PRE_START_SERVER')
        cfg = self.env.get('CONFIG')
        cfg_path = self.env.get('CONFIG_PATH')
        if cfg and cfg_path:
            os.system('cp %s %s' % (cfg, cfg_path))
        self._run('START_SERVER')
        # Truncate the log so that grep_log only sees lines from this run.
        if self.log_path and os.path.exists(self.log_path):
            os.system('> %s' % self.log_path)

    def reload(self):
        """Run the ``RELOAD_SERVER`` command."""
        self._run('RELOAD_SERVER')

    def grep_log(self, level='error'):
        """Return the log lines tagged ``[error]`` (or ``[warn]``).

        Args:
            level: ``'error'`` selects error lines; any other value selects
                warning lines.

        Returns:
            Matching lines, newline included.  An empty list when no log
            path is configured or the log file does not exist, consistent
            with ``start`` treating the log file as optional.
        """
        if not self.log_path or not os.path.exists(self.log_path):
            return []
        pat = re.compile(self.ERR_PAT if level == 'error' else self.WARN_PAT)
        with open(self.log_path, 'r') as f:
            return [l for l in f if pat.match(l)]


def do_http_request(endpoint, body, timeout=5):
    """Send a pre-built raw HTTP request and return the response as text.

    Args:
        endpoint: ``host`` or ``host:port`` (port defaults to 80).
        body: the complete request (request line, headers and body), sent
            as given.
        timeout: connection timeout passed to ``HTTPConnection``.

    Returns:
        The status line, the response headers (names title-cased), an empty
        line and, if non-empty, the response body, joined with CRLF.
    """
    host, port = endpoint.split(':') if ':' in endpoint else (endpoint, 80)
    conn = httplib.HTTPConnection(host, port, timeout=timeout)
    # Close the connection on every path, including a timeout or protocol
    # error raised while sending or reading.
    try:
        conn.send(body)
        # XXX I really don't want to set private member like this,
        # however, I have no other idea how to getresponse after send
        # directly from socket.
        conn._HTTPConnection__state = 'Request-sent'
        resp = conn.getresponse()
        # resp.version is 10 or 11; the last digit is the minor version.
        line = 'HTTP/1.%s %s %s' % (
                resp.version % 10, resp.status, resp.reason)
        data = [line]
        data.extend('%s: %s' % (k.title(), v) for k, v in resp.getheaders())
        data.append('')
        resp_body = resp.read()
    finally:
        conn.close()
    if resp_body:
        data.append(resp_body)
    return '\r\n'.join(data)


class NotMatch(object):
    """Describes a failed comparison: what was expected and what was received.

    ``title`` is an optional heading shown above the report.  The string form
    is a multi-line, ANSI-coloured report.
    """

    def __init__(self, exp, got, title=None):
        self.title = title
        self.exp, self.got = exp, got

    def __str__(self):
        # Optional heading first, then the fixed four-line report.
        heading = ['\033[95m%s\033[0m' % self.title] if self.title else []
        report = [
            '\033[95mexpect:\033[0m',
            '\033[92m%s\033[0m' % self.exp,
            '\033[95mbut got:\033[0m',
            '\033[92m%s\033[0m' % self.got,
        ]
        return '\n'.join(heading + report)

    def __repr__(self):
        return str(self)


class ExecError(object):
    """Describes an exception raised by executed test code.

    Holds the error (``err``), the line number it occurred on (``lineno``)
    and the offending source text (``exc``); the string form is a single
    ANSI-coloured line.
    """

    def __init__(self, err, lineno, exc):
        self.err, self.lineno, self.exc = err, lineno, exc

    def __str__(self):
        fields = (self.err, self.lineno, self.exc)
        return '\033[95m%s in line %s: %s\033[0m' % fields

    def __repr__(self):
        return str(self)


def re_match(expect, got):
    """Return True if the regular expression ``expect`` matches ``got`` entirely.

    The pattern is anchored at both ends (``$`` also accepts a single
    trailing newline).  An ``expect`` that is not a valid regular expression
    yields False rather than raising.
    """
    try:
        # Wrap the pattern in a non-capturing group so that the anchors
        # apply to the whole expression: without it ``a|b`` would compile
        # to ``^a|b$`` and match any string merely starting with ``a``.
        pat = re.compile('^(?:%s)$' % expect)
    except re.error:
        return False
    return pat.match(got) is not None


class ResponseObj(object):
    """Parsed view of an HTTP response.

    Attributes:
        headers: dict built from the CRLF-separated header text.  Lines of
            the form ``Name: value`` map the stripped name to the stripped
            value; any other line (e.g. the status line) maps to ``''``.
        body: the response body as given.
        json: the body decoded as JSON, or ``{}`` if it is not valid JSON.
    """

    def __init__(self, headers, body):
        self.headers = {}
        for h in headers.split('\r\n'):
            name, colon, value = h.partition(':')
            if colon:
                self.headers[name.strip()] = value.strip()
            else:
                self.headers[h] = ''
        self.body = body
        try:
            self.json = json.loads(body)
        except (ValueError, TypeError):
            # Not JSON (or not text at all).  Catch only decoding errors so
            # KeyboardInterrupt/SystemExit are not swallowed.
            self.json = {}


def _header_line_matches(line, got, env):
    """Tell whether expected header ``line`` is satisfied by a line of ``got``.

    A literal hit is tried first; otherwise ``line`` is used as a regular
    expression anchored at both ends.  Named groups of the first matching
    line are merged into ``env``.
    """
    if line in got:
        return True
    try:
        pat = re.compile('^%s$' % line)
    except re.error:
        return False
    for candidate in got:
        found = pat.match(candidate)
        if found:
            captured = found.groupdict()
            if captured:
                env.update(captured)
            return True
    return False


def _run_request(env, case):
    """Send the case's request and verify the response-related blocks.

    Stores the parsed response in ``env['__resp']``.  Returns None when the
    case has no ``request`` block or every check passes, otherwise a
    NotMatch / ExecError describing the first failing check.
    """
    request_blk = case.get_block('request')
    if not request_blk:
        return None

    content = request_blk.filter(env)
    endpoint = env.get('ENDPOINT', '127.0.0.1')
    timeout = env.get('TIMEOUT', 5)

    resp = do_http_request(endpoint, content, timeout=timeout)

    if env.get('DEBUG', False):
        info('request %s' % ('-' * 50))
        info(content)
        info('response %s' % ('-' * 50))
        info(resp)

    # Split the response text at the first blank line into headers and body.
    sep = resp.find('\r\n\r\n')
    resp_headers = resp[:sep] if sep > 0 else resp
    resp_body = '' if sep < 0 else resp[sep+4:]

    env['__resp'] = ResponseObj(resp_headers, resp_body)

    # response: every expected line must appear among the received header
    # lines, literally or as a regular expression.
    response_blk = case.get_block('response')
    if response_blk:
        expect = [l.strip() for l in response_blk.lines(env)]
        got = response_blk.filter(env, resp_headers).split('\r\n')
        if env.get('DEBUG', False):
            info('expect %s' % ('-' * 50))
            info('\n'.join(expect))

        for line in expect:
            if not _header_line_matches(line, got, env):
                return NotMatch('\n'.join(expect), '\n'.join(got))

    # response_body: exact comparison of the (filtered) body.
    response_body_blk = case.get_block('response_body')
    if response_body_blk:
        expect = response_body_blk.content(env)
        got = response_body_blk.filter(env, resp_body)
        if expect != got:
            return NotMatch(expect, got)

    # response_match: the block content is a regular expression for the body.
    response_match_blk = case.get_block('response_match')
    if response_match_blk:
        expect = response_match_blk.content(env)
        got = response_match_blk.filter(env, resp_body)
        if not re_match(expect, got):
            return NotMatch(expect, got)

    # response_eval: the block content is Python code run against env; any
    # exception it raises fails the case.
    response_eval_blk = case.get_block('response_eval')
    if response_eval_blk:
        expect = response_eval_blk.content(env)
        try:
            _exec(expect, None, env)
        except Exception as e:
            # Recover the failing source line from the traceback.
            lineno = EXEC_LINE_RE.findall(traceback.format_exc())
            if lineno and lineno[0].isdigit() and int(lineno[0]) > 0:
                lineno = int(lineno[0])
                exc = expect.split('\n')[lineno - 1]
                return ExecError(type(e), lineno, exc)
            return ExecError(type(e), 'unknown', expect)

    return None


def run_case(srv, env, case):
    """Run one test case and return its first failure, or None on success.

    Order of operations: run the ``setup`` block, send the request and check
    the response blocks, run the ``teardown`` block, then compare the server
    log with the ``error`` / ``warn`` blocks.  The log is only inspected when
    the response checks passed.

    Args:
        srv: object providing ``grep_log(level=...)``.
        env: environment dict of the case; updated in place.
        case: the case to run.

    Returns:
        None, or a NotMatch / ExecError describing the failure.
    """
    def _check_log(level, force=False):
        # Compare the log lines of ``level`` with the block of that name.
        # Without a block, any log line is a failure only when ``force``.
        err_prefix = 'grep %s failed, ' % level

        log_lines = srv.grep_log(level=level)

        blk = case.get_block(level)
        if not blk:
            if force and log_lines:
                return NotMatch([], log_lines, err_prefix)
            return

        lines = [l for l in blk.lines(env) if l.strip()]

        if len(lines) != len(log_lines):
            return NotMatch(lines, log_lines, err_prefix)
        for idx, line in enumerate(lines):
            if not re.match(line, log_lines[idx]):
                return NotMatch(line, log_lines[idx], err_prefix)

    setup_blk = case.get_block('setup')
    if setup_blk:
        setup_blk.filter(env)

    teardown_blk = case.get_block('teardown')

    # Teardown runs exactly once after setup, whether the checks pass, fail
    # or raise (e.g. a request timeout), so state cannot leak to later cases.
    try:
        err = _run_request(env, case)
    finally:
        if teardown_blk:
            teardown_blk.filter(env)
    if err:
        return err

    err = _check_log('error', force=True)
    if err:
        return err
    return _check_log('warn')


def shuffle_cases(cases):
    """Return ``cases`` in random order, keeping every series in sequence.

    Cases whose ``series`` is not None are grouped by series; the members of
    one series keep their original relative order (other cases may end up
    between them).  Cases without a series are placed at random.

    Args:
        cases: sequence of objects with a ``series`` attribute.

    Returns:
        A new list containing the same cases.
    """
    # Explicit list: the free positions are sampled from, removed from and
    # shuffled in place, which a lazy range object does not support.
    free_slots = list(range(len(cases)))
    series = defaultdict(list)
    single = []
    for case in cases:
        se = case.series
        if se is not None:
            series[se].append(case)
        else:
            single.append(case)

    placed = []
    # Materialise the items: random.shuffle needs a mutable sequence.
    series_items = list(series.items())
    random.shuffle(series_items)
    for _, members in series_items:
        # Sorted random slots keep the members in their original order.
        slots = sorted(random.sample(free_slots, len(members)))
        for slot, case in zip(slots, members):
            placed.append((slot, case))
            free_slots.remove(slot)
    random.shuffle(free_slots)
    placed.extend(zip(free_slots, single))
    # Slots are unique; sort on the slot alone so cases are never compared.
    return [case for _, case in sorted(placed, key=lambda pair: pair[0])]


def main():
    """Command-line entry point: gather, run and report the test cases.

    Usage: ``<script> [options] [test_dir_or_file]`` (default ``tests``).
    Every case is run against a freshly started server; cases of the same
    series share one environment, all other cases get their own copy of the
    file environment.  Exits with status 0 when every case passed and 1
    otherwise.
    """
    parser = OptionParser()
    # The pattern is applied to case names (see gather_cases), not to files.
    parser.add_option('-k', '--filter', dest='filter',
            help='regular expression selecting test cases by name')
    parser.add_option('-d', action='store_true', dest='debug',
            help='print request and response')
    parser.add_option('-r', action='store_true', dest='random',
            help='run test cases in random order')
    parser.add_option('-n', action='store_true', dest='name',
            help='print test case names')
    options, args = parser.parse_args()

    test_dir = args[0] if args else 'tests'
    # Make the first component of the test path importable from test code.
    sys.path.append(test_dir.split(os.path.sep)[0])

    test_files = gather_files(test_dir)
    test_cases = gather_cases(test_files, options.filter)

    total_cases = sum(len(cs) for _path, (_env, cs) in test_cases)
    info('%s test case(s) gathered.' % total_cases)

    failed_cases = []
    for fpath, (file_env, cases) in test_cases:
        if options.random or file_env.get('RANDOM', False):
            cases = shuffle_cases(cases)

        if cases and options.name:
            for c in cases:
                info(c.name)
            info('\n')

        # Progress counter "<done>/<total>", rewritten in place by emitting
        # backspaces after each update.
        case_num = len(cases)
        num_width = len(str(case_num))
        info('%s\t...............\t0/%s%s' % \
                (fpath, case_num, '\b' * (num_width + 2)), end='')

        series_env = {}

        for idx, case in enumerate(cases):
            back_width = num_width + 1 + len(str(idx + 1))
            info('%s/%s%s' % (idx + 1, case_num, '\b' * back_width),
                    end='')

            # prepare env for test series: one shared env per series,
            # a private shallow copy for every other case
            series = case.series
            if series is not None:
                if series in series_env:
                    case_env = series_env[series]
                else:
                    case_env = copy.copy(file_env)
                    series_env[series] = case_env
            else:
                case_env = copy.copy(file_env)

            if options.debug:
                case_env['DEBUG'] = True

            # Keep the previous response of the series reachable.
            if '__resp' in case_env:
                case_env['__last_resp'] = case_env.pop('__resp')

            srv = Server(case_env)
            srv.start()
            try:
                err = run_case(srv, case_env, case)
                if err:
                    error('%s:%s, %s' % (fpath, case.name, err))
                    failed_cases.append((fpath, case.name, err))
            except Exception as e:
                error(str(e))
                failed_cases.append((fpath, case.name, str(e)))
            finally:
                srv.stop()
        info('\n', end='')
    info('=' * 50)

    if not failed_cases:
        info('all test case(s) passed.')
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive interpreter and is not always available.
        sys.exit(0)

    error('%s/%s test case(s) failed.' % (len(failed_cases), total_cases))
    for fpath, name, err in failed_cases:
        error('%s:%s, %s' % (fpath, name, err))
    sys.exit(1)


# Run the test runner only when this file is executed as a script.
if __name__ == '__main__':
    main()
